from collections.abc import Mapping, Sequence
from dataclasses import dataclass, field
from logging import getLogger
from typing import (
    Any,
    Callable,
    Literal,
    ParamSpec,
    Protocol,
    Type,
    Union,
    overload,
    runtime_checkable,
)

from pydantic import BaseModel, Field

from inspect_ai._util.dateutil import UtcDatetime, datetime_now_utc
from inspect_ai._util.error import PrerequisiteError
from inspect_ai._util.metadata import MT, metadata_as
from inspect_ai._util.registry import (
    RegistryInfo,
    is_registry_object,
    registry_add,
    registry_create,
    registry_info,
    registry_name,
    registry_params,
    registry_tag,
)

logger = getLogger(__name__)

# Single-letter string codes used as the default score values. These are
# the defaults consumed by `value_to_float()` below, which maps them to
# 1.0 (CORRECT), 0.5 (PARTIAL) and 0 (INCORRECT / NOANSWER).
CORRECT = "C"
"""Value to assign for correct answers."""

INCORRECT = "I"
"""Value to assign for incorrect answers."""

PARTIAL = "P"
"""Value to assign for partial credit."""

NOANSWER = "N"
"""Value to assign for no answer or refusal to answer."""


# A score value is one of:
#   - a scalar (str, int, float or bool)
#   - a sequence of scalars
#   - a string-keyed mapping of scalars (mapping entries may also be None)
Value = Union[
    str | int | float | bool,
    Sequence[str | int | float | bool],
    Mapping[str, str | int | float | bool | None],
]
"""Value provided by a score.

Use the methods of `Score` to easily treat
the `Value` as a simple scalar of various types.
"""

# Used as the default for every editable field of `ScoreEdit`.
UNCHANGED: Literal["UNCHANGED"] = "UNCHANGED"
"""Sentinel value to indicate an unchanged field in score edits."""


class ProvenanceData(BaseModel):
    """Metadata about who made an edit and why.

    Only `author` is required: `timestamp` is filled in by
    `datetime_now_utc` when not supplied, `reason` defaults to None
    and `metadata` defaults to an empty dict.
    """

    # default_factory means each instance gets its own timestamp at creation
    timestamp: UtcDatetime = Field(default_factory=datetime_now_utc)
    """Timestamp when the edit was made."""

    author: str
    """Author who made the edit."""

    reason: str | None = Field(default=None)
    """Reason for the edit."""

    # default_factory gives every instance a fresh dict
    metadata: dict[str, Any] = Field(default_factory=dict)
    """Additional metadata about the edit."""


class ScoreEdit(BaseModel):
    """A single edit to a score.

    Every editable field (`value`, `answer`, `explanation`, `metadata`)
    defaults to the string "UNCHANGED", so an edit only needs to specify
    the fields it actually modifies. `answer` and `explanation` may also
    be set to None explicitly, which is distinct from "UNCHANGED".
    """

    # NOTE(review): `Value` includes str, so a literal score value of
    # "UNCHANGED" cannot be told apart from the sentinel here.
    value: Value | Literal["UNCHANGED"] = "UNCHANGED"
    """New value for the score, or UNCHANGED to keep current value."""

    answer: str | None | Literal["UNCHANGED"] = "UNCHANGED"
    """New answer for the score, or UNCHANGED to keep current answer."""

    explanation: str | None | Literal["UNCHANGED"] = "UNCHANGED"
    """New explanation for the score, or UNCHANGED to keep current explanation."""

    metadata: dict[str, Any] | Literal["UNCHANGED"] = "UNCHANGED"
    """New metadata for the score, or UNCHANGED to keep current metadata."""

    provenance: ProvenanceData | None = None
    """Provenance data for this edit. None indicates this is the original score."""


class Score(BaseModel):
    """Score generated by a scorer.

    Holds the score `value` together with the optional extracted answer,
    an explanation, free-form metadata and the history of edits. The
    `as_*` accessors read `value` as a specific type and raise
    `ValueError` when the stored value has a different shape.
    """

    value: Value
    """Score value."""

    answer: str | None = Field(default=None)
    """Answer extracted from model output (optional)."""

    explanation: str | None = Field(default=None)
    """Explanation of score (optional)."""

    metadata: dict[str, Any] | None = Field(default=None)
    """Additional metadata related to the score."""

    history: list[ScoreEdit] = Field(default_factory=list)
    """Edit history - users can access intermediate states."""

    @property
    def text(self) -> str:
        """Score value as text (same result as `as_str()`)."""
        return self.as_str()

    def as_str(self) -> str:
        """Return the scalar score value converted with `str()`.

        Raises:
          ValueError: If the value is not a scalar.
        """
        scalar = self._as_scalar()
        return str(scalar)

    def as_int(self) -> int:
        """Return the scalar score value converted with `int()`.

        Raises:
          ValueError: If the value is not a scalar.
        """
        scalar = self._as_scalar()
        return int(scalar)

    def as_float(self) -> float:
        """Return the scalar score value converted with `float()`.

        Raises:
          ValueError: If the value is not a scalar.
        """
        scalar = self._as_scalar()
        return float(scalar)

    def as_bool(self) -> bool:
        """Return the scalar score value converted with `bool()`.

        Raises:
          ValueError: If the value is not a scalar.
        """
        scalar = self._as_scalar()
        return bool(scalar)

    def as_list(self) -> list[str | int | float | bool]:
        """Return the score value when it is a list.

        Raises:
          ValueError: If the value is not a `list`.
        """
        if not isinstance(self.value, list):
            raise ValueError("This score is not a list")
        return self.value

    def as_dict(self) -> dict[str, str | int | float | bool | None]:
        """Return the score value when it is a dictionary.

        Raises:
          ValueError: If the value is not a `dict`.
        """
        if not isinstance(self.value, dict):
            raise ValueError("This score is not a dictionary")
        return self.value

    def _as_scalar(self) -> str | int | float | bool:
        # Shared guard for the scalar accessors: only str/int/float/bool pass.
        if not isinstance(self.value, (str, int, float, bool)):
            raise ValueError("This score is not a scalar")
        return self.value


class SampleScore(BaseModel):
    """Score for a Sample.

    Pairs a `Score` with the id and metadata of the sample it was
    computed for and the name of the scorer that produced it.
    """

    score: Score
    """A score"""

    sample_id: str | int | None = Field(default=None)
    """A sample id"""

    sample_metadata: dict[str, Any] | None = Field(default=None)
    """Metadata from the sample"""

    scorer: str | None = Field(default=None)
    """Registry name of scorer that created this score."""

    def sample_metadata_as(self, metadata_cls: Type[MT]) -> MT | None:
        """Pydantic model interface to sample metadata.

        Args:
          metadata_cls: Pydantic model type

        Returns:
          Instance of metadata_cls bound to sample metadata, or None
          when the sample has no metadata.
        """
        if self.sample_metadata is None:
            return None
        return metadata_as(self.sample_metadata, metadata_cls)


ValueToFloat = Callable[[Value], float]
"""Function used by metrics to translate from a Score value to a float value."""


def value_to_float(
    correct: Value = CORRECT,
    incorrect: Value = INCORRECT,
    partial: Value = PARTIAL,
    noanswer: Value = NOANSWER,
) -> ValueToFloat:
    """Create a ValueToFloat function.

    Create a ValueToFloat function that maps scalar values of
    different types into floats. Numeric and boolean values are
    cast to float directly (this check runs before the comparisons
    against `correct`, `partial`, `incorrect` and `noanswer`). The
    specified correct, incorrect, partial, and noanswer values (by
    default "C", "I", "P", and "N") are mapped to 1, 0, 0.5, and 0.
    Note that those are the default literal values, but they can be
    customized. For other strings, common boolean representations
    ('yes', 'no', 'true', 'false', compared case-insensitively) are
    mapped to 1 and 0, and strings accepted by `float()` are
    converted. Anything else (including arrays and dictionaries
    that do not equal one of the configured values) gives a
    warning and returns 0.

    Args:
       correct (Value): Value that represents a correct answer (1)
       incorrect (Value): Value that represents an incorrect answer (0)
       partial (Value): Value to assign partial credit for (0.5)
       noanswer (Value): Value for refusals to answer (0)

    Returns:
        ValueToFloat function. It always returns a `float`.
    """

    def to_float(value: Value) -> float:
        if isinstance(value, int | float | bool):
            return float(value)
        elif value == correct:
            return 1.0
        elif value == partial:
            return 0.5
        elif value == incorrect or value == noanswer:
            # float literal so every branch honours the declared return type
            return 0.0
        elif isinstance(value, str):
            # keep `value` untouched so the warning below reports the
            # original input rather than a lowercased copy
            text = value.lower()
            if text in ("yes", "true"):
                return 1.0
            elif text in ("no", "false"):
                return 0.0
            elif is_number(text):
                return float(text)

        # couldn't extract a value
        logger.warning(f"Unable to convert value to float: {value}")
        return 0.0

    return to_float


def is_number(s: str) -> bool:
    """Return True when `float(s)` succeeds, False when it raises ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    else:
        return True


@runtime_checkable
class MetricDeprecated(Protocol):
    """Callable that computes a metric value from a list of `Score` objects.

    As the name indicates, this is the deprecated metric signature; the
    `MetricProtocol` signature receives `SampleScore` objects instead.
    """

    def __call__(self, scores: list[Score]) -> Value: ...


@runtime_checkable
class MetricProtocol(Protocol):
    """Callable that computes a metric value from a list of `SampleScore` objects."""

    def __call__(self, scores: list[SampleScore]) -> Value:
        r"""Compute a metric on a list of scores.

        Args:
          scores: List of sample scores. Each item wraps a `Score`
            (available as `.score`) with its sample information.

        Returns:
          Metric value

        Examples:
          ```python
          @metric
          def mean() -> Metric:
              def metric(scores: list[SampleScore]) -> Value:
                  return np.mean([score.score.as_float() for score in scores]).item()
              return metric
          ```
        """
        ...


Metric = MetricProtocol | MetricDeprecated
"""Metric protocol.

The Metric signature changed in release v0.3.64. Both
the previous and new signatures are supported -- you
should use `MetricProtocol` for new code as the
deprecated signature will eventually be removed.
"""


P = ParamSpec("P")


@dataclass(frozen=True)
class MetricSpec:
    """Metric specification used to (re-)create metrics.

    Records the registry name of a metric together with the arguments
    it was created with (see `as_metric_spec()`).
    """

    metric: str
    """Metric name"""

    # default_factory gives each spec its own dict
    args: dict[str, Any] = field(default_factory=dict)
    """Metric arguments."""


def metric_register(metric: Callable[P, Metric], name: str = "") -> Callable[P, Metric]:
    r"""Add a metric factory to the registry under type "metric".

    Args:
        metric (MetricType):
            Function that returns a Metric or class
            deriving from Metric
        name (str): Name of metric (Optional, defaults to the
            `__name__` of the passed object)

    Returns:
        The same object that was passed in, after it has been
        added to the registry.
    """
    # An empty name falls back to the object's own __name__.
    info = RegistryInfo(type="metric", name=name or getattr(metric, "__name__"))
    registry_add(metric, info)
    return metric


def metric_create(name: str, **kwargs: Any) -> Metric:
    r"""Create a Metric based on its registered name.

    Metrics can be functions that return a Metric or classes
    deriving from Metric

    Args:
        name (str): Registered name of the metric to create
        **kwargs (dict): Optional creation arguments for the metric

    Returns:
        Metric created by the registry for the given name.
    """
    created = registry_create("metric", name, **kwargs)
    return created


def to_metric_specs(
    metrics: list[Metric | dict[str, list[Metric]]] | dict[str, list[Metric]],
) -> list[MetricSpec | dict[str, list[MetricSpec]]] | dict[str, list[MetricSpec]]:
    """Convert metrics into `MetricSpec` objects, preserving the input shape.

    Args:
        metrics: Either a list whose items are metrics or dicts of
            named metric groups, or a single dict of named metric groups.

    Returns:
        A structure of the same shape with every metric replaced by the
        result of `as_metric_spec()`.
    """

    def specs_for_groups(
        groups: dict[str, list[Metric]],
    ) -> dict[str, list[MetricSpec]]:
        # Convert each named group of metrics into a group of specs.
        return {
            group: [as_metric_spec(member) for member in members]
            for group, members in groups.items()
        }

    # Anything that is not a list is treated as a dict of metric groups.
    if not isinstance(metrics, list):
        return specs_for_groups(metrics)

    # List items are either a dict of metric groups or a direct metric.
    converted: list[MetricSpec | dict[str, list[MetricSpec]]] = [
        specs_for_groups(entry) if isinstance(entry, dict) else as_metric_spec(entry)
        for entry in metrics
    ]
    return converted


def as_metric_spec(metric: Metric) -> MetricSpec:
    """Build the `MetricSpec` (registry name and params) for a registered metric.

    Raises:
        PrerequisiteError: If `metric` is not a registry object.
    """
    if is_registry_object(metric):
        spec_name = registry_info(metric).name
        return MetricSpec(metric=spec_name, args=registry_params(metric))

    # Not tagged by the registry: report it by name when it has one.
    label = getattr(metric, "__name__", "<unknown>")
    raise PrerequisiteError(
        f"The metric {label} was not created by a function decorated with @metric so cannot be recorded."
    )


@overload
def metric(name: str) -> Callable[[Callable[P, Metric]], Callable[P, Metric]]: ...


@overload
# type: ignore
def metric(name: Callable[P, Metric]) -> Callable[P, Metric]: ...


def metric(
    name: str | Callable[P, Metric],
) -> Callable[[Callable[P, Metric]], Callable[P, Metric]] | Callable[P, Metric]:
    r"""Decorator for registering metrics.

    Can be applied bare (`@metric`) or with an explicit name
    (`@metric("my_name")`).

    Args:
      name: Optional name for metric. If the decorator has no name
        argument then the name of the underlying MetricType
        will be used to automatically assign a name.

    Returns:
      The registered wrapper when applied bare, otherwise a decorator
      that registers the function it is applied to under `name`.

    Examples:
      ```python
      @metric
      def mean() -> Metric:
          def metric(scores: list[SampleScore]) -> Value:
              return np.mean([score.score.as_float() for score in scores]).item()
          return metric
      ```
    """

    # create_metric_wrapper:
    #  (a) Add the MetricType to the registry using the appropriately
    #      package-namespaced name
    #  (b) Ensure that instances of Metric created by MetricType also
    #      carry registry info.
    def create_metric_wrapper(
        metric_type: Callable[P, Metric], explicit_name: str | None = None
    ) -> Callable[P, Metric]:
        # fall back to the factory's own __name__ when no name was given
        registered_name = registry_name(
            metric_type, explicit_name or getattr(metric_type, "__name__")
        )

        def metric_wrapper(*args: P.args, **kwargs: P.kwargs) -> Metric:
            # build the metric, then tag the instance with registry info
            # and the arguments it was created with
            instance = metric_type(*args, **kwargs)
            registry_tag(
                metric_type,
                instance,
                RegistryInfo(type="metric", name=registered_name),
                *args,
                **kwargs,
            )
            return instance

        return metric_register(metric_wrapper, registered_name)

    # bare usage: `name` is actually the metric_type being decorated
    if not isinstance(name, str):
        return create_metric_wrapper(name)

    # for decorators with an explicit name, one more wrapper for the name
    decorator_name: str = name

    def wrapper(metric_type: Callable[P, Metric]) -> Callable[P, Metric]:
        return create_metric_wrapper(metric_type, decorator_name)

    return wrapper
